import math
import random
import numpy as np
# import util
import csv
import os
import torch.utils.data as Data
from torch.utils.data import Dataset
import torch
from torch import nn as nn
from torch import Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer, TransformerDecoder


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    Implements the non-learned encoding from "Attention Is All You Need":
    even embedding channels carry sine waves, odd channels cosine waves,
    at geometrically spaced frequencies.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        :param d_model: embedding dimension of the inputs
        :param dropout: dropout probability applied after adding the encoding
        :param max_len: longest sequence length the lookup table supports
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the full (max_len, 1, d_model) table once; the middle
        # singleton axis broadcasts over the batch dimension in forward().
        positions = torch.arange(max_len).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        table = torch.zeros(max_len, 1, d_model)
        table[:, 0, 0::2] = torch.sin(positions * freqs)
        table[:, 0, 1::2] = torch.cos(positions * freqs)
        # Registered as a buffer (not a Parameter): it follows .to()/.cuda()
        # and is saved in state_dict, but is never trained.
        self.register_buffer('pe', table)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        return self.dropout(x + self.pe[:x.size(0)])


class transformer_model(nn.Module):
    '''
    [Transformer 2+2] + pool + FC scoring network: a 2-layer Transformer
    encoder plus 2-layer decoder with two outputs — per-frame class logits
    and temporally average-pooled features (computed from the decoded
    sequence).
    '''

    def __init__(self, input_size, output_size, bptt, step, trans_input_size):
        '''
        Build the scoring network.
        :param input_size: dimensionality of the raw input features
        :param output_size: dimensionality of the per-frame output features
        :param bptt: segment length the input sequence is split into
        :param step: pooling window size (frames averaged per output step)
        :param trans_input_size: feature size fed to the Transformer
        '''
        super(transformer_model, self).__init__()

        d_hid = 200  # dimension of the feedforward network model in nn.TransformerEncoder
        nlayers = 2  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        nhead = 4  # number of heads in nn.MultiheadAttention; trans_input_size must be divisible by this
        dropout = 0.2  # dropout probability
        self.pos_encoder = PositionalEncoding(trans_input_size, dropout)
        encoder_layers = TransformerEncoderLayer(trans_input_size, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        self.trans_input_size = trans_input_size
        self.output_size = output_size
        decoder_layers = TransformerDecoderLayer(trans_input_size, nhead, d_hid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        # Projects raw features down to the Transformer width.
        self.pre_encoder = nn.Linear(input_size, trans_input_size)

        self.fc = nn.Linear(trans_input_size, output_size)
        self.classifier = nn.Linear(output_size, 101)  # 101-way classifier (THUMOS'14 classes, presumably — verify)

        # Kept for interface compatibility; not used inside forward().
        self.softmax = nn.Softmax(dim=1)
        self.bptt = bptt
        self.time_step = step
        self.avgpool = nn.AdaptiveAvgPool1d(1)

    def forward(self, src: Tensor, src_mask: Tensor):
        '''
        Run inference.
        :param src: input features, shape (batch, length, input_size);
                    length must be a multiple of self.bptt
        :param src_mask: optional attention mask for the encoder (may be None)
        :return: tuple of
                 - per-frame class logits, shape (batch*length, 101)
                 - pooled features, shape (batch, length // time_step, output_size)
        '''
        bptt = self.bptt
        trans_input_size = self.trans_input_size
        output_size = self.output_size
        projected = self.pre_encoder(src)  # reduce feature dimension

        length = src.size(1)
        batch_size = src.size(0)
        seg_count = length // bptt
        # Reinterpret the flat (batch, length, dim) tensor as seg_count
        # segments of bptt frames each.
        # NOTE(review): after this permute the segment axis (batch*seg_count)
        # becomes dim 0, which the Transformer treats as the sequence
        # dimension — confirm this layout is intended.
        segments = projected.view(bptt, batch_size * seg_count, trans_input_size)
        segments = segments.permute([1, 0, 2]).contiguous()

        # Transformer: predict per-frame class and action-interval features.
        x = segments * math.sqrt(trans_input_size)
        x = self.pos_encoder(x)
        # BUGFIX: src_mask was previously ignored (mask=None was hard-coded);
        # callers passing None still get the old behavior.
        encoded = self.transformer_encoder(x, mask=src_mask)
        decoded = self.transformer_decoder(x, encoded)
        decoded = decoded.permute([1, 0, 2]).contiguous()
        seq_fout = decoded.view(-1, trans_input_size)  # N x trans_input_size
        seq_fout = self.fc(seq_fout)

        # Average-pool every time_step consecutive frames.
        pooled = seq_fout.view(-1, self.time_step, output_size)
        pooled = pooled.permute([0, 2, 1])
        pooled = self.avgpool(pooled)
        pooled = pooled.permute([0, 2, 1])
        seq_fout = seq_fout.view(-1, output_size)
        seq_avg = pooled.view(batch_size, -1, output_size)

        seq_out = self.classifier(seq_fout)
        return seq_out, seq_avg


if __name__ == '__main__':
    # Smoke test: load precomputed features for one video and run the model.
    path = "J:\\Thumos14\\thumos14\\val\\video_validation_0000051_s112_f2048.npz"
    data = np.load(path, allow_pickle=True)
    feature, fcode = data["x"], data["v"]

    start = 0
    end = 768
    # Slice directly rather than building a Python list element by element
    # (the old loop also shadowed the builtin `list`).
    t = torch.as_tensor(np.asarray(feature[start:end]))
    print(t.shape)  # torch.Size([768, 2048])
    t = t.unsqueeze(0)  # add a batch dimension
    print(t.shape)

    input_size = 2048
    output_size = 32
    bptt = 48
    step = 8
    trans_input_size = 512
    model = transformer_model(input_size=input_size, output_size=output_size, bptt=bptt, step=step,
                              trans_input_size=trans_input_size)
    # Call the module itself (not .forward) so nn.Module hooks run.
    ret, ret1 = model(t, None)
    print(ret.shape)  # (batch*length, 101) class logits
    ret = ret.unsqueeze(0).permute(0, 2, 1).unsqueeze(3).unsqueeze(4)
    print(ret.shape)
    print(ret1.shape)
